from langchain_openai import ChatOpenAI
from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_text_splitters import Language
from langchain_community.document_loaders.parsers import LanguageParser
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_huggingface import HuggingFaceEmbeddings
import chainlit as cl
import utils
import torch
import os
import re
import subprocess
import time

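# Chainlit-based RAG assistant for the Gitee / Gitee AI documentation:
# it clones the docs repositories, splits their Markdown into chunks, embeds
# the chunks with bge-m3 into a Chroma vector store, and answers questions
# with glm-4-9b-chat served through an OpenAI-compatible API (e.g. vLLM).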

DEBUG = os.environ.get("DEBUG", "false").lower() in ("true", "1", "yes")

model_name = "hf-models/glm-4-9b-chat"  # gitee-docs-lora

api_server_ready = False

GITEE_ACCESS_TOKEN = os.environ.get("GITEE_ACCESS_TOKEN", "")

if not GITEE_ACCESS_TOKEN:
    print("GITEE_ACCESS_TOKEN environment variable is not set")

base_url = "http://127.0.0.1:8000/v1/"
repo_path = "/repos"
git_clone_ai_doc_command = [
    "git", "clone",
    f"https://oauth2:{GITEE_ACCESS_TOKEN}@gitee.com/gitee-ai/docs.git",
    "--depth=1", "--single-branch", repo_path + "/ai.gitee.com",
]

git_clone_gitee_doc_command = [
    "git", "clone",
    f"https://oauth2:{GITEE_ACCESS_TOKEN}@gitee.com/oschina/gitee-help-center.git",
    "--depth=1", "--single-branch", repo_path + "/help.gitee.com",
]


def clone_doc_repo():
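    """Shallow-clone the Gitee AI docs and Gitee help-center repos into repo_path."""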
    try:
        # check=True makes a failed clone raise instead of passing silently
        subprocess.run(git_clone_ai_doc_command, text=True, check=True)
        subprocess.run(git_clone_gitee_doc_command, text=True, check=True)
    except Exception as e:
        print("Error while cloning the docs repositories:", e)


db = ""
retriever = ""

bge_model_name = "hf-models/bge-m3"
bge_model_kwargs = {'device': 'cuda:1'}  # 'torch_dtype': torch.float16
bge_encode_kwargs = {'normalize_embeddings': False}
hfEmbeddings = HuggingFaceEmbeddings(
    model_name=bge_model_name,
    model_kwargs=bge_model_kwargs,
    encode_kwargs=bge_encode_kwargs,
)

# When running several models across GPUs, vLLM can report "GPU blocks: 0";
# see https://github.com/vllm-project/vllm/issues/2248
torch.cuda.empty_cache()
print('Embedding model ready')

# Read the choice-gitee-docs.json file
# with open('./data/choice-gitee-docs.json', 'r', encoding='utf-8') as file:
#     choice_gitee_docs = json.load(file)

# # Extract every 'output' field
# choice_gitee_docs_output = [doc['output']
#                             for doc in choice_gitee_docs if 'output' in doc]
# # Rewrite the links inside the documents; entries containing a slug are handled by the AI.

# choice_gitee_docs_output_docs = RecursiveJsonSplitter(
#     max_chunk_size=2000).create_documents(texts=[choice_gitee_docs_output])


# json_loader = JSONLoader(
#     file_path='./data/choice-gitee-docs.json',
#     jq_schema='.',
#     text_content=False)

# choice_gitee_docs_output_docs = json_loader.load()


def update_sources(documents):
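    """Rewrite each repo document's local clone path into its public URL,
    prepend a reference link to page_content, and merge in the loose
    files loaded from ./data.
    """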
    utils.load_network_data()
    dir_loader = DirectoryLoader(
        './data', glob="**/[!.]*", loader_cls=TextLoader, show_progress=True)
    dir_loader_document = dir_loader.load()
    print("./data document count:", len(dir_loader_document))
    for doc in documents:
        source = doc.metadata.get('source')
        # /repos/ai.gitee.com/docs/xxx
        # /repos/help.gitee.com/docs/xxx
        if source and source.startswith(repo_path + '/'):
            # Turn the /repos/ prefix into https://
            source = source.replace(repo_path + '/', 'https://', 1)
            # Drop a trailing .md
            if source.endswith('.md'):
                source = source[:-3]
            # For help.gitee.com, drop the /docs segment
            if 'help.gitee.com/docs' in source:
                source = source.replace('/docs', '', 1)

        # Check whether page_content declares a slug
        slug_match = re.search(r'slug: (/.+)', doc.page_content)
        if slug_match:
            slug = slug_match.group(1)
            doc.page_content = f'Reference link: https://help.gitee.com{slug}\n\n{doc.page_content}'
        elif source:
            # Prepend the source URL to page_content
            doc.page_content = f'Reference link: {source}\n\n{doc.page_content}'

        # Remove source from metadata
        if 'source' in doc.metadata:
            del doc.metadata['source']

    return dir_loader_document + documents


update_doc_db_ing = False


def update_doc_db():
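    """Rebuild the vector store: wipe and re-clone the docs repos, load the
    Markdown/text files, split them into chunks, embed them into Chroma, and
    refresh the global retriever.
    """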
    global update_doc_db_ing
    global db
    global retriever
    if update_doc_db_ing:
        return
    update_doc_db_ing = True
    print('Building the vector store')

    subprocess.run(["rm", "-rf", repo_path])
    clone_doc_repo()
    loader = GenericLoader.from_filesystem(
        repo_path,
        glob="**/*",
        suffixes=[".md", ".txt"],
        exclude=["**/non-utf8-encoding.py", ".git/**"],
        parser=LanguageParser(parser_threshold=0),
    )
    documents = loader.load()
    documents = update_sources(documents)
    markdown_splitter = RecursiveCharacterTextSplitter.from_language(
        # Keep chunks small enough for the embedding batch limit:
        # ValueError: Batch size 51030 exceeds maximum batch size 41666
        language=Language.MARKDOWN, chunk_size=1800, chunk_overlap=400
    )
    docs = markdown_splitter.split_documents(documents)
    print("Sample chunks:", docs[5:10])
    print("Number of chunks:", len(docs))
    # Alternative: embed through the OpenAI-compatible API. For non-OpenAI
    # backends, tiktoken_enabled must be set to False.
    # db = Chroma.from_documents(docs, OpenAIEmbeddings(disallowed_special=(), api_key="EMPTY", base_url=base_url, model="hf-models/bge-m3", timeout=300,
    #                                                   tiktoken_enabled=False, show_progress_bar=True, chunk_size=8192), persist_directory="./chroma_db")
    db = Chroma.from_documents(docs, hfEmbeddings)
    retriever = db.as_retriever(
        # search_type can be 'similarity', 'similarity_score_threshold' or 'mmr'
        search_type="similarity",
        search_kwargs={"k": 13}
    )
    update_doc_db_ing = False
    print('Vector store ready')
    torch.cuda.empty_cache()



update_doc_db()
# chunk_size is the character length of each chunk, i.e. how many characters
# each retrieved item holds. Code files need fairly large chunks for their
# content to be recalled intact: 300 lines of code can run to ~6,000 characters.
# chunk_overlap is the overlap between adjacent chunks; it ensures content at
# chunk boundaries is not lost after splitting.
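# For illustration (not executed): with chunk_size=1800 and chunk_overlap=400,
# a 3,200-character section splits roughly into chars [0:1800] and [1400:3200],
# so the ~400 characters around the boundary land in both chunks and a
# sentence straddling the cut remains retrievable as a whole.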


llm = ChatOpenAI(model=model_name, api_key="EMPTY", base_url=base_url,
                 # stop token IDs: glm4 [151329, 151336, 151338], qwen2 [151643, 151644, 151645]
                 callbacks=[StreamingStdOutCallbackHandler()],
                 streaming=True, temperature=0.2, presence_penalty=1.2, top_p=0.95, extra_body={"stop_token_ids": [151329, 151336, 151338]})


@cl.on_chat_start
async def start_chat():
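    """Greet the user when a new chat session starts."""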
    # cl.user_session.set(
    #     "message_history",
    #     [{"role": "system", "content": "xxx"}],
    # )
    await cl.Message(content="Hi! I'm 马建仓 (Majiancang), and I'm happy to help with any problem you hit while using Gitee AI!\nAI output is for documentation reference only and does not represent official views.").send()


@cl.on_message
async def main(message: cl.Message):
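    """Handle one user message: check that the API server and input are
    usable, retrieve relevant doc chunks, inject them into the system
    prompt, and stream the model's answer back token by token.
    """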
    gen_time_start = time.time()
    global api_server_ready
    global db
    global retriever

    if not api_server_ready:
        api_server_ready = utils.is_port_open(base_url)
        if not api_server_ready:
            return await cl.Message(content="The API server is still starting, please retry...").send()
    if not message.content:
        return await cl.Message(content="Please enter your question").send()
    if len(message.content) > 1000:
        return await cl.Message(content="Input must not exceed 1,000 characters!").send()
    print("\ninput:", message.content)
    if not db:
        # Rebuilding the vector store takes roughly three minutes
        update_doc_db()
    search_res = str(retriever.invoke(message.content))
    if DEBUG:
        print("Retrieved results:", search_res)
    system_message = {"role": "system", "content": f"""
- 【你十分幽默风趣活泼，名叫马建仓，是 Gitee 吉祥物，现在在给 Gitee 打工，是一名金牌客服】。
- 你的回答参考 markdown 文档内容，紧扣用户问题，使用 markdown 格式回答。
- 注意区分用户问题, 如有必要, Gitee 和 Gitee AI 分开回答, 上下文提供了 Gitee 和 Gitee AI 两种文档。
- 用户提问模糊不清时，如有必要，你可以继续询问用户，细化问题。
- 回答时，确保：
    1. 文档中有 https 链接（图片、链接等），如有必要，请提供完整，可以使用 markdown 渲染，相对路径和不相关图片请勿提供。
    2. 如有必要，回答末尾可提供前少量参考文档 https:// 链接, 方便用户了解更多, 【注意: 请勿编造链接，请勿总结混合链接，没有就不要提供】。
    3. 遵守中国法律，不回复任何敏感、违法、违反道德的问题。
    4. 你的知识和上下文文档冲突时，以文档为准。
    5. 如果你不知道答案，请勿捏造假内容。
- markdown 文档内容为: {search_res}
"""}

    message_history = cl.user_session.get(
        "message_history") or []
    message_history.append({"role": "user", "content": message.content})

    ai_msg = cl.Message(content="")
    await ai_msg.send()
    # Prepend system_message on every turn, but never store it in the session
    # history (note: list.insert mutates in place and returns None)
    stream = llm.stream([system_message] + message_history)
    if DEBUG:
        print("history", str(message_history))
    full_answer = ""
    for part in stream:
        if content := part.content or "":
            # Log latency when the very first token arrives
            if not full_answer:
                print("Time to first token:", time.time() - gen_time_start)
            full_answer += content
            await ai_msg.stream_token(content)
    message_history.append({"role": "assistant", "content": ai_msg.content})
    cl.user_session.set("message_history", message_history)
    await ai_msg.update()
